class UnionFind:
    """Disjoint-set forest over the integers ``0 .. n-1``.

    ``p[x]`` is the parent of ``x`` (a root is its own parent) and
    ``size[r]`` is the number of elements in the set whose root is ``r``.
    ``size`` entries are only kept up to date for current roots.
    """

    __slots__ = "p", "size"

    def __init__(self, n: int):
        """Create ``n`` singleton sets, one per index."""
        self.p = list(range(n))
        self.size = [1] * n

    def find(self, x: int) -> int:
        """Return the root of the set containing ``x``.

        Every node visited on the way up is re-parented directly onto the
        root (path compression).
        """
        parent = self.p

        # First pass: climb to the root.
        root = x
        while parent[root] != root:
            root = parent[root]

        # Second pass: point each node on the path straight at the root.
        node = x
        while parent[node] != root:
            parent[node], node = root, parent[node]

        return root

    def union(self, a: int, b: int) -> bool:
        """Merge the sets containing ``a`` and ``b``.

        Returns ``False`` when both already share a root, ``True`` when a
        merge took place. The strictly larger set absorbs the other; on a
        size tie the root of ``b`` becomes the root of the merged set.
        """
        keep, drop = self.find(a), self.find(b)
        if keep == drop:
            return False

        # Only a strictly larger left-hand set keeps its root.
        if self.size[keep] <= self.size[drop]:
            keep, drop = drop, keep

        self.p[drop] = keep
        self.size[keep] += self.size[drop]
        return True

    def get_size(self, root: int) -> int:
        """Return the stored size for ``root``.

        No ``find`` is performed, so the value is only meaningful when
        ``root`` is currently the root of its set.
        """
        return self.size[root]


class Solution:
    """Pick the initially infected node whose removal protects the most nodes."""

    def minMalwareSpread(self, graph: List[List[int]], initial: List[int]) -> int:
        """Return the node from ``initial`` whose removal saves the most clean nodes.

        Args:
            graph: ``n x n`` adjacency matrix; a truthy ``graph[i][j]`` means
                ``i`` and ``j`` are connected. Clean-to-clean edges are only
                read for ``j > i``, so the matrix is assumed to be symmetric.
            initial: Indices of the initially infected nodes. Repeated entries
                are treated as a single node.

        Returns:
            The infected node that maximises the total size of the clean
            components touched by it and by no other infected node. Ties are
            broken by the smallest index. Returns ``0`` when ``initial`` is
            empty.
        """
        n = len(graph)
        infected = set(initial)

        # Group the clean nodes into connected components, ignoring every
        # infected node (as if all of them had been removed from the graph).
        uf = UnionFind(n)
        for i in range(n):
            if i in infected:
                continue
            for j in range(i + 1, n):
                if graph[i][j] and j not in infected:
                    uf.union(i, j)

        # For each infected node, the clean components it is directly adjacent
        # to; for each component, how many distinct infected nodes touch it.
        # Iterating the de-duplicated set keeps a repeated entry in `initial`
        # from counting itself twice and losing its credit.
        touched = {}
        sources = Counter()
        for i in infected:
            roots = {
                uf.find(j)
                for j in range(n)
                if j not in infected and graph[i][j]
            }
            touched[i] = roots
            sources.update(roots)

        # A component is credited to a node only when that node is the sole
        # infected node adjacent to it.
        best, best_saved = 0, -1
        for i in infected:
            saved = sum(
                uf.get_size(root) for root in touched[i] if sources[root] == 1
            )
            if saved > best_saved or (saved == best_saved and i < best):
                best, best_saved = i, saved
        return best
